Online system identification in Duffing oscillator by free energy minimisation

This project considers a Duffing oscillator: a driven, damped harmonic oscillator with a cubic nonlinearity in its spring stiffness component. The state-space description of the system is:

$$\begin{align} m \frac{d^2 x(t)}{dt^2} + c \frac{d x(t)}{dt} + a x(t) + b x^3(t) =&\ u(t) + w(t) \\ y(t) =&\ x(t) + v(t) \end{align}$$

where $$\begin{align} m =&\ \text{mass} \\ c =&\ \text{damping} \\ a =&\ \text{linear stiffness} \\ b =&\ \text{nonlinear stiffness} \\ y(t) =&\ \text{observation (displacement)} \\ x(t) =&\ \text{state (displacement)} \\ u(t) =&\ \text{force} \\ v(t) =&\ \text{measurement noise} \\ w(t) =&\ \text{process noise} \end{align}$$

The process noise is a Wiener process whose increments are Gaussian distributed, $w(t) \sim \mathcal{N}(0, \tau^{-1}dt)$; the parameter $\tau$ represents the precision of the process. The measurement noise is also a Wiener process, $v(t) \sim \mathcal{N}(0, \xi^{-1}dt)$, with precision parameter $\xi$.
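As a concrete illustration, the sketch below simulates this system with a simple Euler–Maruyama scheme. All parameter values, the sinusoidal forcing, and the function name simulate_duffing are assumptions for illustration only; they are not the Silverbox system's values.

using Random

# Sketch: Euler–Maruyama simulation of the Duffing SDE above.
# Parameters and forcing are assumed, not the Silverbox values.
function simulate_duffing(; m=1.0, c=0.1, a=-1.0, b=1.0, τ=1e4, ξ=1e4,
                          dt=1e-3, T=10_000, rng=MersenneTwister(1))
    x, xdot, y = zeros(T), zeros(T), zeros(T)        # position, velocity, observation
    u = sin.(2π .* dt .* (1:T))                      # assumed sinusoidal forcing
    for t in 1:T-1
        dW = sqrt(dt/τ) * randn(rng)                 # process noise increment w(t)
        xdot[t+1] = xdot[t] + (dt*(u[t] - c*xdot[t] - a*x[t] - b*x[t]^3) + dW) / m
        x[t+1] = x[t] + dt*xdot[t]
        y[t+1] = x[t+1] + sqrt(dt/ξ)*randn(rng)      # measurement noise v(t)
    end
    return x, y, u
end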

Forecasting experiment: 1-step ahead prediction error

In this notebook, we will perform a forecasting experiment. At each point in time, we will make a prediction for the next observation. That prediction will be evaluated and compared to a few benchmark methods.

Data

There is an electronic implementation of the Duffing oscillator, called Silverbox, on the Nonlinear System Identification Benchmark website: http://nonlinearbenchmark.org/#Silverbox.

In [1]:
using Revise
using CSV
using DataFrames
In [2]:
using Plots
viz = true;
In [3]:
# Read data from CSV file into a DataFrame
df = CSV.read("data/SNLS80mV.csv", DataFrame; ignoreemptylines=true)
df = select(df, [:V1, :V2])

# Sampling frequency
fs = 610.35

# Shorthand
input = df[:,1]
output = df[:,2]

# Time horizon
T = size(df, 1);
In [13]:
# Select training set
ix_trn = collect(40101:131072)
input_trn = input[ix_trn]
output_trn = output[ix_trn]
T_trn = length(ix_trn);

# Select validation set
ix_val = 1:40100 
input_val = input[ix_val]
output_val = output[ix_val]
T_val = length(ix_val);
In [104]:
# Plot example of signals
ss = 4
ix = 126000:ss:127500
p31 = Plots.plot(ix, output[ix], color="green", markersize=2, xlabel="time (t)", label="output (displacement)", linestyle=:dashdot)    
Plots.plot!(ix, input[ix], color="purple", markersize=2, xlabel="time (t)", label="input (control)", size=(1200,300), ylim=[-.16, .21], legend=:topright, tickfontsize=14, legendfontsize=12, ylabel="signal", guidefontsize=16)
Out[104]:
In [6]:
Plots.savefig(p31, "figures/input-output_seq1.png")
Plots.savefig(p31, "figures/input-output_seq1.pdf")

Solution steps

1. Discretize

I'm using a central difference for the second derivative and a forward difference for the first derivative, absorbing the sampling interval into the coefficients. Let $w_t$ be a sample from $\mathcal{N}(0, \tau^{-1})$. The state transition can now be written as the following discrete-time system:

$$\begin{align} m (x_{t+1} - 2x_{t} + x_{t-1}) + c (x_{t+1} - x_{t}) + a x_t + b x_t^3 =&\ u_t + w_t \end{align}$$

Re-writing this as a function of $x_{t+1}$ yields: $$\begin{align} (m + c)\, x_{t+1} + (-2m - c + a)\, x_{t} + bx_t^3 + m x_{t-1} =&\ u_t + w_t \\ x_{t+1} =&\ \frac{2m + c - a}{m + c} x_{t} + \frac{-b}{m + c}x_t^3 + \frac{-m}{m + c} x_{t-1} + \frac{1}{m + c} u_t + \frac{1}{m + c} w_t \, . \end{align}$$
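As a quick sanity check, the snippet below (with assumed numerical values) verifies that the explicit update satisfies the implicit difference equation:

# Check the rearrangement with assumed parameter and signal values
m, c, a, b = 1.0, 0.1, -1.0, 1.0
x_t, x_tmin1, u_t, w_t = 0.3, 0.25, 0.5, 0.01

# Explicit one-step update
x_tplus1 = ((2m + c - a)*x_t - b*x_t^3 - m*x_tmin1 + u_t + w_t) / (m + c)

# The implicit difference equation should balance to the driving terms
lhs = m*(x_tplus1 - 2x_t + x_tmin1) + c*(x_tplus1 - x_t) + a*x_t + b*x_t^3
@assert isapprox(lhs, u_t + w_t, atol=1e-12)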

2. Substitute variables and reduce order

I can cast the above system into matrix form:

$$ \underbrace{\begin{bmatrix} x_{t+1} \\ x_{t} \end{bmatrix}}_{z_t} = \underbrace{\begin{bmatrix} 0 & 0 \\ 1 & 0 \end{bmatrix}}_{S} \underbrace{\begin{bmatrix} x_{t} \\ x_{t-1} \end{bmatrix}}_{z_{t-1}} + \underbrace{\begin{bmatrix} 1 \\ 0 \end{bmatrix}}_{s} g(\theta, z_{t-1}) + \begin{bmatrix} 1 \\ 0 \end{bmatrix} \eta u_t + \begin{bmatrix} 1 \\ 0 \end{bmatrix} \tilde{w}_t \, ,$$

where

$$\begin{align} \theta_1 = \frac{2m+c-a}{m+c} \ , \quad \theta_2 = \frac{-b}{m+c} \ , \quad \theta_3 = \frac{-m}{m+c} \ , \quad \eta = \frac{1}{m+c} \ , \quad \gamma^{-1} = \frac{\tau^{-1}}{(m+c)^2} \, , \end{align}$$

with $g(\theta, z_{t-1}) = \theta_1 x_t + \theta_2 x_t^3 + \theta_3 x_{t-1}$ and $\tilde{w}_t \sim \mathcal{N}(0, \gamma^{-1})$. In total, I have five unknowns $m,c,a,b,\tau$ and five equations. I can invert the mapping between $\phi = (m, c, a, b, \tau)$ and $\psi = (\theta_1, \theta_2, \theta_3, \eta, \gamma)$ to recover MAP estimates for the physical parameters. An additional advantage of variable substitution is that it allows for more freedom in choosing priors.
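A sketch of this substitution and its inverse follows directly from the definitions above; the function names phys2subst and subst2phys are mine, for illustration only.

# Sketch: forward mapping ψ = G(φ) and its inverse.
function phys2subst(m, c, a, b, τ)
    θ1 = (2m + c - a) / (m + c)
    θ2 = -b / (m + c)
    θ3 = -m / (m + c)
    η  = 1 / (m + c)
    γ  = τ * (m + c)^2          # from γ⁻¹ = τ⁻¹/(m+c)²
    return θ1, θ2, θ3, η, γ
end

function subst2phys(θ1, θ2, θ3, η, γ)
    m = -θ3 / η                 # since m + c = 1/η
    c = (1 + θ3) / η
    a = (1 - θ1 - θ3) / η
    b = -θ2 / η
    τ = γ * η^2
    return m, c, a, b, τ
end

# Round-trip check with assumed physical parameters
φ = (1.0, 0.1, -1.0, 1.0, 1e4)
@assert all(isapprox.(subst2phys(phys2subst(φ...)...), φ))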

The system is now a nonlinear autoregressive process:

$$z_t = f(\theta, z_{t-1}, \eta, u_t) + \tilde{w}_t$$

where $f(\theta, z_{t-1}, \eta, u_t) = Sz_{t-1} + s g(\theta, z_{t-1}) + s \eta u_t$. Note that the states are two-dimensional now.
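In code, the transition reads as follows. This is a sketch, but g matches the nonlinearity that is passed to the ForneyLab model below.

# Compact restatement of the two-dimensional transition f(θ, z, η, u)
S = [0.0 0.0; 1.0 0.0]
s = [1.0, 0.0]
g(θ, z) = θ[1]*z[1] + θ[2]*z[1]^3 + θ[3]*z[2]
f(θ, z, η, u) = S*z + s*g(θ, z) + s*η*u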

3. Convert to Gaussian probability

Integrating out $\tilde{w}_t$ and $v_t$ produces a Gaussian state transition node:

$$\begin{align} z_t \sim&\ \mathcal{N}(f(\theta, z_{t-1}, \eta, u_t), V) \\ y_t \sim&\ \mathcal{N}(s^{\top} z_t, \xi^{-1}) \, , \end{align}$$

where $V = \begin{bmatrix} \gamma^{-1} & 0 \\ 0 & \epsilon \end{bmatrix}$ and $W = V^{-1} = \begin{bmatrix} \gamma & 0 \\ 0 & \epsilon^{-1} \end{bmatrix}$. Here $\epsilon$ is a small positive jitter: the second component of $z_t$ is a deterministic copy of $x_t$, and the jitter keeps $V$ invertible.

4. Approximating the nonlinearity

The nonlinearity is approximated using a first-order Taylor expansion. The work here revolves around working out the expectations of $g(\theta, x)$:

$$ g(\theta, x) = g(m_{\theta}, m_x) + J_{x}(m_{\theta}, m_x)^{\top}(x - m_x) + J_{\theta}(m_{\theta}, m_x)^{\top}(\theta - m_{\theta}) \, ,$$

where $J_x$ denotes the partial derivative of $g$ with respect to $x$ and $J_{\theta}$ w.r.t. $\theta$. Note that our current $g$ is linear in $\theta$, so one could argue that the expansion in $\theta$ is unnecessary. However, this form is more general, and the first-order Taylor expansion is exact in $\theta$ anyway.
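For this particular $g$ the Jacobians are available in closed form. The sketch below states them and checks $J_x$ against finite differences at an assumed evaluation point.

# Analytic Jacobians of g for the first-order Taylor expansion
g(θ, x) = θ[1]*x[1] + θ[2]*x[1]^3 + θ[3]*x[2]
Jx(θ, x) = [θ[1] + 3θ[2]*x[1]^2, θ[3]]   # ∂g/∂x
Jθ(θ, x) = [x[1], x[1]^3, x[2]]          # ∂g/∂θ (exact: g is linear in θ)

# Finite-difference check at an assumed point
θ0, x0, ϵ = [0.9, -0.1, 0.2], [0.3, 0.25], 1e-6
fd = [(g(θ0, x0 .+ ϵ .* ((1:2) .== i)) - g(θ0, x0)) / ϵ for i in 1:2]
@assert all(isapprox.(Jx(θ0, x0), fd; atol=1e-4))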

5. Choose priors

We know that mass $m$ and process precision $\tau$ are strictly positive parameters and that the damping and stiffness coefficients can be both positive and negative. By examining the nonlinear transform $\psi = G(\phi)$, we realize that $\theta_1$, $\theta_2$, $\theta_3$ and $\eta$ can be both positive and negative, but $\gamma$ can only be positive. As such, we choose the following priors:

$$\begin{align} \theta \sim \text{Normal}(m^{0}_{\theta}, V^{0}_{\theta}) \ , \quad \eta \sim \text{Normal}(m^{0}_{\eta}, v^{0}_{\eta}) \ , \quad \gamma \sim \text{Gamma}(a^{0}_\gamma, b^{0}_\gamma) \, . \end{align}$$

6. Choose recognition model

The recognition model is a mean-field factorization over states and parameters, with each factor matching the functional form of the corresponding prior:

$$\begin{align} q(\theta) = \text{Normal}(m_{\theta}, V_{\theta}) \ , \quad q(\eta) = \text{Normal}(m_{\eta}, v_{\eta}) \ , \quad q(\gamma) = \text{Gamma}(a_\gamma, b_\gamma) \, , \end{align}$$

with analogous Normal factors for the states and a Gamma factor for the measurement precision $\xi$.

Implementation

The procedure described above was implemented using ForneyLab.jl with a custom node called "NLARX", which implements a Nonlinear Latent AutoRegressive model with eXogenous input for the state transition.

In [7]:
using LinearAlgebra
using ForneyLab
using ForneyLab: unsafeMean, unsafeCov, unsafeVar, unsafePrecision
using ProgressMeter

include("NLARX-node/NLARX.jl")
include("NLARX-node/util.jl")
using .NLARX
In [56]:
# System identification graph
graph1 = FactorGraph()

# Static parameters
@RV θ ~ GaussianMeanPrecision(placeholder(:m_θ, dims=(3,)), placeholder(:w_θ, dims=(3,3)))
@RV η ~ GaussianMeanPrecision(placeholder(:m_η), placeholder(:w_η))
@RV γ ~ Gamma(placeholder(:a_γ), placeholder(:b_γ))
@RV ξ ~ Gamma(placeholder(:a_ξ), placeholder(:b_ξ))

# Nonlinearity
g(θ, x) = θ[1]*x[1] + θ[2]*x[1]^3 + θ[3]*x[2]

# State prior
@RV z_tmin1 ~ GaussianMeanPrecision(placeholder(:m_z, dims=(2,)), placeholder(:w_z, dims=(2, 2)), id=:z_tmin1)

# Autoregressive node
@RV z_t ~ NLatentAutoregressiveX(θ, z_tmin1, η, placeholder(:u_t), γ, g=g, id=:z_t)

# Specify likelihood
@RV y_t ~ GaussianMeanPrecision(dot([1. , 0.], z_t), ξ, id=:y_t)

# Placeholder for observation
placeholder(y_t, :y_t)

# Draw time-slice subgraph
ForneyLab.draw(graph1)
[Factor graph rendering omitted: the NLARX node z_t connects the parameter priors on θ, η, and γ, the state prior z_tmin1, and the input placeholder u_t; its output passes through a dot product into the Gaussian likelihood y_t with precision ξ.]
In [73]:
# Specify recognition model
q1 = PosteriorFactorization(z_t, z_tmin1, θ, η, γ, ξ, ids=[:z_t, :z_tmin1, :θ, :η, :γ, :ξ])
algo1 = variationalAlgorithm(q1, free_energy=true)

# Compile inference algorithm
source_code1 = algorithmSourceCode(algo1, free_energy=true)
eval(Meta.parse(source_code1));
# println(source_code1)

Infer parameters on training data

In [58]:
# Inference parameters
num_iterations = 10

# Initialize marginal distribution and observed data dictionaries
data = Dict()
marginals = Dict()

# Initialize free energy tracking array
free_energy_trn = zeros(T_trn, num_iterations)

# Initialize arrays of parameterizations
params_z = (zeros(2,T_trn+1), repeat(.1 .*float(eye(2)), outer=(1,1,T_trn+1)))
params_θ = (ones(3,T_trn+1), repeat(.1 .*float(eye(3)), outer=(1,1,T_trn+1)))
params_η = (2*ones(1,T_trn+1), 1e2 *ones(1,T_trn+1))
params_γ = (1e8*ones(1,T_trn+1), 1e3*ones(1,T_trn+1))
params_ξ = (1e8*ones(1,T_trn+1), 1e1*ones(1,T_trn+1))

# Start progress bar
p = Progress(T_trn, 1, "At time ")

# Perform inference at each time-step
for t = 1:T_trn

    # Update progress bar
    update!(p, t)

    # Initialize marginals
    marginals[:z_tmin1] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_z[1][:,t], w=params_z[2][:,:,t])
    marginals[:z_t] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_z[1][:,t], w=params_z[2][:,:,t])
    marginals[:θ] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_θ[1][:,t], w=params_θ[2][:,:,t])
    marginals[:η] = ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=params_η[1][1,t], w=params_η[2][1,t])
    marginals[:γ] = ProbabilityDistribution(Univariate, Gamma, a=params_γ[1][1,t], b=params_γ[2][1,t])
    marginals[:ξ] = ProbabilityDistribution(Univariate, Gamma, a=params_ξ[1][1,t], b=params_ξ[2][1,t])
    
    data = Dict(:y_t => output_trn[t],
                :u_t => input_trn[t],
                :m_z => params_z[1][:,t],
                :w_z => params_z[2][:,:,t],
                :m_θ => params_θ[1][:,t],
                :w_θ => params_θ[2][:,:,t],
                :m_η => params_η[1][1,t],
                :w_η => params_η[2][1,t],
                :a_γ => params_γ[1][1,t],
                :b_γ => params_γ[2][1,t],
                :a_ξ => params_ξ[1][1,t],
                :b_ξ => params_ξ[2][1,t])

    # Iterate variational parameter updates
    for i = 1:num_iterations

        # Update parameters
        stepη!(data, marginals)
        stepθ!(data, marginals)
        
        # Update states
        stepz_t!(data, marginals)
        stepz_tmin1!(data, marginals)
        
        # Update noise
        stepγ!(data, marginals)
        stepξ!(data, marginals)
        
        # Compute free energy
        free_energy_trn[t, i] = freeEnergy(data, marginals)
        
    end

    # Store current parameterizations of marginals
    params_z[1][:,t+1] = unsafeMean(marginals[:z_t])
    params_z[2][:,:,t+1] = marginals[:z_t].params[:w]
    params_θ[1][:,t+1] = unsafeMean(marginals[:θ])
    params_θ[2][:,:,t+1] = marginals[:θ].params[:w]
    params_η[1][1,t+1] = unsafeMean(marginals[:η])
    params_η[2][1,t+1] = marginals[:η].params[:w]
    params_γ[1][1,t+1] = marginals[:γ].params[:a]
    params_γ[2][1,t+1] = marginals[:γ].params[:b]
    params_ξ[1][1,t+1] = marginals[:ξ].params[:a]
    params_ξ[2][1,t+1] = marginals[:ξ].params[:b]

end
At time 100%|███████████████████████████████████████████| Time: 0:04:00

K-step ahead prediction on validation data

In [59]:
# Prediction graph
graph2 = FactorGraph()

# Autoregressive node
@RV z_pred ~ NLatentAutoregressiveX(placeholder(:θ, dims=(3,)), placeholder(:z_tmin1, dims=(2,)), placeholder(:η), placeholder(:u_t), placeholder(:γ), g=g, id=:z_pred_t)

# Draw time-slice subgraph
ForneyLab.draw(graph2)

# Inference algorithm
q2 = PosteriorFactorization(z_pred, ids=[:z_pred])
algo2 = variationalAlgorithm(q2, free_energy=true)
source_code2 = algorithmSourceCode(algo2, free_energy=true)
eval(Meta.parse(source_code2));
# println(source_code2)
[Factor graph rendering omitted: the NLARX node z_pred_t with placeholders for θ, z_tmin1, η, u_t, and γ, producing the prediction z_pred.]
begin

function stepz_pred!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 1))

    messages[1] = ruleVariationalNLARXOutNPPPPP(g, nothing, ProbabilityDistribution(Multivariate, PointMass, m=data[:z_tmin1]), ProbabilityDistribution(Multivariate, PointMass, m=data[:θ]), ProbabilityDistribution(Univariate, PointMass, m=data[:η]), ProbabilityDistribution(Univariate, PointMass, m=data[:u_t]), ProbabilityDistribution(Univariate, PointMass, m=data[:γ]))

    marginals[:z_pred] = messages[1].dist

    return marginals

end

function freeEnergy(data::Dict, marginals::Dict)

    F = 0.0

    F += averageEnergy(NLatentAutoregressiveX, g, marginals[:z_pred], ProbabilityDistribution(Multivariate, PointMass, m=data[:z_tmin1]), ProbabilityDistribution(Multivariate, PointMass, m=data[:θ]), ProbabilityDistribution(Univariate, PointMass, m=data[:η]), ProbabilityDistribution(Univariate, PointMass, m=data[:u_t]), ProbabilityDistribution(Univariate, PointMass, m=data[:γ]))

    return F

end

end # block
In [60]:
# Number of steps ahead
K = 1

# Initialize free energy tracking array
free_energy_pred = zeros(T_val, num_iterations)

# Initialize future state arrays
params_preds = (zeros(2, T_val - K, K+1), repeat(.1 .*float(eye(2)), outer=(1,1,T_val - K, K+1)))

# Start progress bar
p = Progress(T_val - K, 1, "At time ")

for t = 3:(T_val - K)
    update!(p, t)

    for k = 1:K

        # Initialize marginals
        marginals[:z_pred] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_preds[1][:,t,k], w=params_preds[2][:,:,t,k])

        # Clamp data
        data = Dict(:u_t => input[t+k],
                    :z_tmin1 => output_val[t+k-1:-1:t+k-2],
                    :θ => params_θ[1][:,end],
                    :η => params_η[1][end],
                    :γ => params_γ[1][end]/params_γ[2][end])

        # Iterate variational parameter updates
        for i = 1:num_iterations
            
            # Make prediction
            stepz_pred!(data, marginals)
            
            # Compute free energy
            free_energy_pred[t, i] = freeEnergy(data, marginals)
        end

        # Store current parameterizations of marginals
        params_preds[1][:,t,k+1] = unsafeMean(marginals[:z_pred])
        params_preds[2][:,:,t,k+1] = marginals[:z_pred].params[:w]

    end
end
At time 100%|███████████████████████████████████████████| Time: 0:00:19
In [61]:
# Store predictions for later comparisons
results_NLARX = Dict()
results_NLARX["preds"] = params_preds
results_NLARX["params_z"] = params_z;
results_NLARX["params_θ"] = params_θ;
results_NLARX["params_η"] = params_η;
results_NLARX["params_γ"] = params_γ;
results_NLARX["params_ξ"] = params_ξ;
results_NLARX["FE_pred"] = free_energy_pred;
results_NLARX["FE_trn"] = free_energy_trn;

Visualize results

In [62]:
# Mean and std dev of predictions
predictions_mean = params_preds[1][1,:,K+1]
predictions_std = sqrt.(inv.(params_preds[2][1,1,:,K+1])) 

# Subsample for visualization
ss = 10
# viz_ix = 50000:ss:60000
viz_ix = 1024:ss:40000

# Plot predictions
p23 = scatter(viz_ix, output[viz_ix], label="observations", xlabel="time (t)", ylims=[-1., 1.])
plot!(viz_ix, predictions_mean[viz_ix], ribbon=[predictions_std[viz_ix], predictions_std[viz_ix]], label="predictions")
Out[62]:
In [63]:
Plots.savefig(p23, "figures/predictions_nlarx.png")
In [64]:
# Compute prediction error
pred_error = (predictions_mean[2:end] .- output_val[3:end]).^2

# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Scatter error over time
p24 = scatter(viz_ix, pred_error[viz_ix], color="black", xlabel="time (t)", ylabel="Prediction error", label="", yscale=:log10)
Out[64]:
In [65]:
Plots.savefig(p24, "figures/pred-error_nlarx.png")
In [71]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Plot free energy over time
p24 = plot(viz_ix, free_energy_trn[viz_ix,end], color="black", xlabel="time (t)", ylabel="F[q]", label="", title="Free energy at training time")
Out[71]:
In [72]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Plot free energy over time
p24 = plot(viz_ix, free_energy_pred[viz_ix,end], color="black", xlabel="time (t)", ylabel="F[q]", label="", title="Free energy of predictions")
Out[72]:

Baseline: linear autoregression

In [74]:
# System identification graph
graph3 = FactorGraph()

# Static parameters
@RV θ ~ GaussianMeanPrecision(placeholder(:m_θ, dims=(2,)), placeholder(:w_θ, dims=(2,2)))
@RV η ~ GaussianMeanPrecision(placeholder(:m_η), placeholder(:w_η))
@RV γ ~ Gamma(placeholder(:a_γ), placeholder(:b_γ))
@RV ξ ~ Gamma(placeholder(:a_ξ), placeholder(:b_ξ))

# Linear autoregression function
g(θ, x) = θ[1]*x[1] + θ[2]*x[2]

# State prior
@RV z_tmin1 ~ GaussianMeanPrecision(placeholder(:m_z, dims=(2,)), placeholder(:w_z, dims=(2, 2)), id=:z_tmin1)

# Autoregressive node
@RV z_t ~ NLatentAutoregressiveX(θ, z_tmin1, η, placeholder(:u_t), γ, g=g, id=:z_t)

# Specify likelihood
@RV y_t ~ GaussianMeanPrecision(dot([1. , 0.], z_t), ξ, id=:y_t)

# Placeholder for observation
placeholder(y_t, :y_t)

# Specify recognition model
q3 = PosteriorFactorization(z_t, z_tmin1, θ, η, γ, ξ, ids=[:z_t, :z_tmin1, :θ, :η, :γ, :ξ])
algo3 = variationalAlgorithm(q3, free_energy=true)

# Compile inference algorithm
source_code3 = algorithmSourceCode(algo3, free_energy=true)
eval(Meta.parse(source_code3));
# println(source_code3)
In [75]:
# Inference parameters
num_iterations = 10

# Initialize marginal distribution and observed data dictionaries
data = Dict()
marginals = Dict()

# Initialize free energy tracking array
free_energy_trn = zeros(T_trn, num_iterations)

# Initialize arrays of parameterizations
params_z = (zeros(2,T_trn+1), repeat(.1 .*float(eye(2)), outer=(1,1,T_trn+1)))
params_θ = (ones(2,T_trn+1), repeat(.1 .*float(eye(2)), outer=(1,1,T_trn+1)))
params_η = (2*ones(1,T_trn+1), 1e2 *ones(1,T_trn+1))
params_γ = (1e8*ones(1,T_trn+1), 1e3*ones(1,T_trn+1))
params_ξ = (1e8*ones(1,T_trn+1), 1e1*ones(1,T_trn+1))

# Start progress bar
p = Progress(T_trn, 1, "At time ")

# Perform inference at each time-step
for t = 1:T_trn

    # Update progress bar
    update!(p, t)

    # Initialize marginals
    marginals[:z_tmin1] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_z[1][:,t], w=params_z[2][:,:,t])
    marginals[:z_t] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_z[1][:,t], w=params_z[2][:,:,t])
    marginals[:θ] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_θ[1][:,t], w=params_θ[2][:,:,t])
    marginals[:η] = ProbabilityDistribution(Univariate, GaussianMeanPrecision, m=params_η[1][1,t], w=params_η[2][1,t])
    marginals[:γ] = ProbabilityDistribution(Univariate, Gamma, a=params_γ[1][1,t], b=params_γ[2][1,t])
    marginals[:ξ] = ProbabilityDistribution(Univariate, Gamma, a=params_ξ[1][1,t], b=params_ξ[2][1,t])
    
    data = Dict(:y_t => output_trn[t],
                :u_t => input_trn[t],
                :m_z => params_z[1][:,t],
                :w_z => params_z[2][:,:,t],
                :m_θ => params_θ[1][:,t],
                :w_θ => params_θ[2][:,:,t],
                :m_η => params_η[1][1,t],
                :w_η => params_η[2][1,t],
                :a_γ => params_γ[1][1,t],
                :b_γ => params_γ[2][1,t],
                :a_ξ => params_ξ[1][1,t],
                :b_ξ => params_ξ[2][1,t])

    # Iterate variational parameter updates
    for i = 1:num_iterations

        # Update parameters
        stepη!(data, marginals)
        stepθ!(data, marginals)
        
        # Update states
        stepz_t!(data, marginals)
        stepz_tmin1!(data, marginals)
        
        # Update noise
        stepγ!(data, marginals)
        stepξ!(data, marginals)
        
        # Compute free energy
        free_energy_trn[t, i] = freeEnergy(data, marginals)
        
    end

    # Store current parameterizations of marginals
    params_z[1][:,t+1] = unsafeMean(marginals[:z_t])
    params_z[2][:,:,t+1] = marginals[:z_t].params[:w]
    params_θ[1][:,t+1] = unsafeMean(marginals[:θ])
    params_θ[2][:,:,t+1] = marginals[:θ].params[:w]
    params_η[1][1,t+1] = unsafeMean(marginals[:η])
    params_η[2][1,t+1] = marginals[:η].params[:w]
    params_γ[1][1,t+1] = marginals[:γ].params[:a]
    params_γ[2][1,t+1] = marginals[:γ].params[:b]
    params_ξ[1][1,t+1] = marginals[:ξ].params[:a]
    params_ξ[2][1,t+1] = marginals[:ξ].params[:b]

end
At time 100%|███████████████████████████████████████████| Time: 0:05:49
In [76]:
# Prediction graph
graph4 = FactorGraph()

# Autoregressive node
@RV z_pred ~ NLatentAutoregressiveX(placeholder(:θ, dims=(2,)), placeholder(:z_tmin1, dims=(2,)), placeholder(:η), placeholder(:u_t), placeholder(:γ), g=g, id=:z_pred_t)

# Inference algorithm
q4 = PosteriorFactorization(z_pred, ids=[:z_pred])
algo4 = variationalAlgorithm(q4, free_energy=true)
source_code4 = algorithmSourceCode(algo4, free_energy=true)
eval(Meta.parse(source_code4));
In [77]:
# Number of steps ahead
K = 1

# Initialize free energy tracking array
free_energy_pred = zeros(T_val, num_iterations)

# Initialize future state arrays
params_preds = (zeros(2, T_val - K, K+1), repeat(.1 .*float(eye(2)), outer=(1,1,T_val - K, K+1)))

# Start progress bar
p = Progress(T_val - K, 1, "At time ")

for t = 3:(T_val - K)
    update!(p, t)

    for k = 1:K

        # Initialize marginals
        marginals[:z_pred] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=params_preds[1][:,t,k], w=params_preds[2][:,:,t,k])

        # Clamp data
        data = Dict(:u_t => input[t+k],
                    :z_tmin1 => output_val[t+k-1:-1:t+k-2],
                    :θ => params_θ[1][:,end],
                    :η => params_η[1][end],
                    :γ => params_γ[1][end]/params_γ[2][end])

        # Iterate variational parameter updates
        for i = 1:num_iterations
            
            # Make prediction
            stepz_pred!(data, marginals)
            
            # Compute free energy
            free_energy_pred[t, i] = freeEnergy(data, marginals)
        end

        # Store current parameterizations of marginals
        params_preds[1][:,t,k+1] = unsafeMean(marginals[:z_pred])
        params_preds[2][:,:,t,k+1] = marginals[:z_pred].params[:w]

    end
end
At time 100%|███████████████████████████████████████████| Time: 0:00:22
In [78]:
# Store predictions for later comparisons
results_LARX = Dict()
results_LARX["preds"] = params_preds
results_LARX["params_z"] = params_z;
results_LARX["params_θ"] = params_θ;
results_LARX["params_η"] = params_η;
results_LARX["params_γ"] = params_γ;
results_LARX["params_ξ"] = params_ξ;
results_LARX["FE_pred"] = free_energy_pred;
results_LARX["FE_trn"] = free_energy_trn;

Visualize results

In [81]:
# Mean and std dev of predictions
predictions_mean = params_preds[1][1,:,K+1]
predictions_std = sqrt.(inv.(params_preds[2][1,1,:,K+1])) 

# Subsample for visualization
ss = 10
# viz_ix = 50000:ss:60000
viz_ix = 1024:ss:40000

# Plot predictions
p230 = scatter(viz_ix, output[viz_ix], label="observations", xlabel="time (t)", ylims=[-1., 1.])
plot!(viz_ix, predictions_mean[viz_ix], ribbon=[predictions_std[viz_ix], predictions_std[viz_ix]], label="predictions")
Out[81]:
In [82]:
Plots.savefig(p230, "figures/predictions_larx.png")
In [85]:
# Compute prediction error
pred_error = (predictions_mean[2:end] .- output_val[3:end]).^2

# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Scatter error over time
p240 = scatter(viz_ix, pred_error[viz_ix], color="black", xlabel="time (t)", ylabel="Prediction error", label="", yscale=:log10)
Out[85]:
In [86]:
Plots.savefig(p240, "figures/pred-error_larx.png")
In [87]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Plot free energy over time
p24 = plot(viz_ix, free_energy_trn[viz_ix,end], color="black", xlabel="time (t)", ylabel="F[q]", label="", title="Free energy at training time")
Out[87]:
In [88]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Plot free energy over time
p24 = plot(viz_ix, free_energy_pred[viz_ix,end], color="black", xlabel="time (t)", ylabel="F[q]", label="", title="Free energy of predictions")
Out[88]:

Baseline: offline canonical state-space model

As a baseline, a state-space model in observable canonical form,

$$\begin{align} x_{t+1} =&\ Ax_t + Bu_t + Ke_t \\ y_t =&\ sx_t + e_t \, , \end{align}$$

was estimated. A nonlinear least-squares method with BFGS as the optimizer was used. This model was trained in MATLAB using the System Identification Toolbox; the results are imported here for visualization purposes.

In [94]:
using MAT
results_SSCANON = matread("results/results_sscanon.mat")
Out[94]:
Dict{String,Any} with 2 entries:
  "pred_states" => [0.00627981; 0.00810739; … ; -0.0516563; 0.00107345]
  "pred_error"  => [9.72186e-6; 2.36542e-5; … ; 0.00258957; 0.00499599]
In [96]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Plot predictions
p402 = scatter(viz_ix, output[viz_ix], label="observations", ylims=[-1., 1.])
plot!(viz_ix, results_SSCANON["pred_states"][viz_ix], label="predictions", xlabel="time (t)")
Out[96]:
In [100]:
# Subsample for visualization
ss = 10
viz_ix = 1:ss:40000

# Scatter error over time
p403 = scatter(viz_ix, results_SSCANON["pred_error"][viz_ix], color="black", xlabel="time (t)", ylabel="Prediction error", label="", yscale=:log10)
Out[100]:

Comparison

In [99]:
# Compute MSE for each step ahead
MSE_LARX = mean((results_LARX["preds"][1][1,:,K+1] .- output_val[2:end]).^2)
MSE_NLARX = mean((results_NLARX["preds"][1][1,:,K+1] .- output_val[2:end]).^2)
MSE_SSCANON = mean(results_SSCANON["pred_error"])

println("MSE LARX = "*string(MSE_LARX))
println("MSE NLARX = "*string(MSE_NLARX))
println("MSE SSCANON = "*string(MSE_SSCANON))
MSE LARX = 5.944884328986888e-5
MSE NLARX = 5.830771054107066e-5
MSE SSCANON = 0.0031517145146173647